{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Encoder-Decoder Architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The encoder-decoder architecture is a neural network design pattern in which the network is partitioned\n",
    "into two parts: the encoder and the decoder. The encoder’s role is to encode the inputs into a state, which\n",
    "often consists of several tensors. This state is then passed to the decoder to generate the outputs. In\n",
    "machine translation, for example, the encoder transforms a source sentence such as “Hello world.” into a\n",
    "state (e.g., a vector) that captures its semantic information. The decoder then uses this state to generate\n",
    "the translated target sentence, e.g., “Bonjour le monde.”."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg height=\"34pt\" version=\"1.1\" viewBox=\"0 0 304 34\" width=\"304pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<defs>\n",
       "<g>\n",
       "<symbol id=\"glyph0-0\" overflow=\"visible\">\n",
       "<path d=\"M 1.125 0 L 1.125 -5.625 L 5.625 -5.625 L 5.625 0 Z M 1.265625 -0.140625 L 5.484375 -0.140625 L 5.484375 -5.484375 L 1.265625 -5.484375 Z M 1.265625 -0.140625 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-1\" overflow=\"visible\">\n",
       "<path d=\"M 0.71875 0 L 0.71875 -6.4375 L 5.375 -6.4375 L 5.375 -5.6875 L 1.5625 -5.6875 L 1.5625 -3.703125 L 5.125 -3.703125 L 5.125 -2.953125 L 1.5625 -2.953125 L 1.5625 -0.765625 L 5.515625 -0.765625 L 5.515625 0 Z M 0.71875 0 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-2\" overflow=\"visible\">\n",
       "<path d=\"M 0.59375 0 L 0.59375 -4.671875 L 1.3125 -4.671875 L 1.3125 -4 C 1.644531 -4.507812 2.140625 -4.765625 2.796875 -4.765625 C 3.078125 -4.765625 3.332031 -4.710938 3.5625 -4.609375 C 3.800781 -4.515625 3.976562 -4.382812 4.09375 -4.21875 C 4.207031 -4.0625 4.289062 -3.867188 4.34375 -3.640625 C 4.375 -3.492188 4.390625 -3.238281 4.390625 -2.875 L 4.390625 0 L 3.59375 0 L 3.59375 -2.84375 C 3.59375 -3.164062 3.5625 -3.40625 3.5 -3.5625 C 3.4375 -3.71875 3.328125 -3.84375 3.171875 -3.9375 C 3.015625 -4.039062 2.832031 -4.09375 2.625 -4.09375 C 2.289062 -4.09375 2 -3.984375 1.75 -3.765625 C 1.507812 -3.554688 1.390625 -3.148438 1.390625 -2.546875 L 1.390625 0 Z M 0.59375 0 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-3\" overflow=\"visible\">\n",
       "<path d=\"M 3.640625 -1.703125 L 4.421875 -1.609375 C 4.335938 -1.066406 4.117188 -0.644531 3.765625 -0.34375 C 3.410156 -0.0390625 2.976562 0.109375 2.46875 0.109375 C 1.832031 0.109375 1.320312 -0.0976562 0.9375 -0.515625 C 0.550781 -0.929688 0.359375 -1.53125 0.359375 -2.3125 C 0.359375 -2.820312 0.441406 -3.265625 0.609375 -3.640625 C 0.773438 -4.015625 1.023438 -4.296875 1.359375 -4.484375 C 1.703125 -4.671875 2.078125 -4.765625 2.484375 -4.765625 C 2.984375 -4.765625 3.394531 -4.632812 3.71875 -4.375 C 4.039062 -4.125 4.25 -3.765625 4.34375 -3.296875 L 3.578125 -3.171875 C 3.503906 -3.484375 3.375 -3.71875 3.1875 -3.875 C 3 -4.039062 2.773438 -4.125 2.515625 -4.125 C 2.109375 -4.125 1.78125 -3.976562 1.53125 -3.6875 C 1.289062 -3.40625 1.171875 -2.957031 1.171875 -2.34375 C 1.171875 -1.707031 1.289062 -1.25 1.53125 -0.96875 C 1.769531 -0.6875 2.082031 -0.546875 2.46875 -0.546875 C 2.78125 -0.546875 3.039062 -0.640625 3.25 -0.828125 C 3.457031 -1.015625 3.585938 -1.304688 3.640625 -1.703125 Z M 3.640625 -1.703125 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-4\" overflow=\"visible\">\n",
       "<path d=\"M 0.296875 -2.328125 C 0.296875 -3.191406 0.535156 -3.832031 1.015625 -4.25 C 1.421875 -4.59375 1.910156 -4.765625 2.484375 -4.765625 C 3.128906 -4.765625 3.65625 -4.554688 4.0625 -4.140625 C 4.46875 -3.722656 4.671875 -3.144531 4.671875 -2.40625 C 4.671875 -1.800781 4.578125 -1.328125 4.390625 -0.984375 C 4.210938 -0.640625 3.953125 -0.367188 3.609375 -0.171875 C 3.265625 0.015625 2.890625 0.109375 2.484375 0.109375 C 1.828125 0.109375 1.296875 -0.0976562 0.890625 -0.515625 C 0.492188 -0.941406 0.296875 -1.546875 0.296875 -2.328125 Z M 1.109375 -2.328125 C 1.109375 -1.734375 1.238281 -1.285156 1.5 -0.984375 C 1.757812 -0.691406 2.085938 -0.546875 2.484375 -0.546875 C 2.878906 -0.546875 3.207031 -0.691406 3.46875 -0.984375 C 3.726562 -1.285156 3.859375 -1.742188 3.859375 -2.359375 C 3.859375 -2.929688 3.726562 -3.367188 3.46875 -3.671875 C 3.207031 -3.972656 2.878906 -4.125 2.484375 -4.125 C 2.085938 -4.125 1.757812 -3.972656 1.5 -3.671875 C 1.238281 -3.378906 1.109375 -2.929688 1.109375 -2.328125 Z M 1.109375 -2.328125 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-5\" overflow=\"visible\">\n",
       "<path d=\"M 3.625 0 L 3.625 -0.59375 C 3.320312 -0.125 2.882812 0.109375 2.3125 0.109375 C 1.945312 0.109375 1.609375 0.00390625 1.296875 -0.203125 C 0.984375 -0.410156 0.738281 -0.695312 0.5625 -1.0625 C 0.394531 -1.425781 0.3125 -1.847656 0.3125 -2.328125 C 0.3125 -2.796875 0.390625 -3.21875 0.546875 -3.59375 C 0.703125 -3.976562 0.929688 -4.269531 1.234375 -4.46875 C 1.546875 -4.664062 1.894531 -4.765625 2.28125 -4.765625 C 2.5625 -4.765625 2.8125 -4.707031 3.03125 -4.59375 C 3.25 -4.476562 3.425781 -4.320312 3.5625 -4.125 L 3.5625 -6.4375 L 4.359375 -6.4375 L 4.359375 0 Z M 1.125 -2.328125 C 1.125 -1.734375 1.25 -1.285156 1.5 -0.984375 C 1.75 -0.691406 2.046875 -0.546875 2.390625 -0.546875 C 2.734375 -0.546875 3.023438 -0.6875 3.265625 -0.96875 C 3.515625 -1.25 3.640625 -1.679688 3.640625 -2.265625 C 3.640625 -2.898438 3.515625 -3.367188 3.265625 -3.671875 C 3.015625 -3.972656 2.710938 -4.125 2.359375 -4.125 C 2.003906 -4.125 1.707031 -3.976562 1.46875 -3.6875 C 1.238281 -3.394531 1.125 -2.941406 1.125 -2.328125 Z M 1.125 -2.328125 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-6\" overflow=\"visible\">\n",
       "<path d=\"M 3.78125 -1.5 L 4.609375 -1.40625 C 4.472656 -0.925781 4.226562 -0.550781 3.875 -0.28125 C 3.53125 -0.0195312 3.085938 0.109375 2.546875 0.109375 C 1.867188 0.109375 1.328125 -0.0976562 0.921875 -0.515625 C 0.523438 -0.941406 0.328125 -1.535156 0.328125 -2.296875 C 0.328125 -3.078125 0.53125 -3.679688 0.9375 -4.109375 C 1.34375 -4.546875 1.867188 -4.765625 2.515625 -4.765625 C 3.140625 -4.765625 3.644531 -4.550781 4.03125 -4.125 C 4.425781 -3.707031 4.625 -3.113281 4.625 -2.34375 C 4.625 -2.289062 4.625 -2.21875 4.625 -2.125 L 1.140625 -2.125 C 1.171875 -1.613281 1.316406 -1.222656 1.578125 -0.953125 C 1.835938 -0.679688 2.164062 -0.546875 2.5625 -0.546875 C 2.851562 -0.546875 3.097656 -0.617188 3.296875 -0.765625 C 3.503906 -0.921875 3.664062 -1.164062 3.78125 -1.5 Z M 1.1875 -2.78125 L 3.796875 -2.78125 C 3.765625 -3.175781 3.664062 -3.472656 3.5 -3.671875 C 3.25 -3.972656 2.921875 -4.125 2.515625 -4.125 C 2.148438 -4.125 1.84375 -4 1.59375 -3.75 C 1.351562 -3.507812 1.21875 -3.1875 1.1875 -2.78125 Z M 1.1875 -2.78125 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-7\" overflow=\"visible\">\n",
       "<path d=\"M 0.578125 0 L 0.578125 -4.671875 L 1.296875 -4.671875 L 1.296875 -3.953125 C 1.472656 -4.285156 1.640625 -4.503906 1.796875 -4.609375 C 1.953125 -4.710938 2.125 -4.765625 2.3125 -4.765625 C 2.570312 -4.765625 2.84375 -4.679688 3.125 -4.515625 L 2.84375 -3.78125 C 2.65625 -3.894531 2.460938 -3.953125 2.265625 -3.953125 C 2.097656 -3.953125 1.941406 -3.898438 1.796875 -3.796875 C 1.660156 -3.691406 1.5625 -3.546875 1.5 -3.359375 C 1.414062 -3.078125 1.375 -2.769531 1.375 -2.4375 L 1.375 0 Z M 0.578125 0 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-8\" overflow=\"visible\">\n",
       "<path d=\"M 0.6875 0 L 0.6875 -6.4375 L 2.90625 -6.4375 C 3.414062 -6.4375 3.800781 -6.40625 4.0625 -6.34375 C 4.425781 -6.257812 4.738281 -6.109375 5 -5.890625 C 5.34375 -5.597656 5.597656 -5.226562 5.765625 -4.78125 C 5.929688 -4.34375 6.015625 -3.832031 6.015625 -3.25 C 6.015625 -2.757812 5.957031 -2.328125 5.84375 -1.953125 C 5.726562 -1.578125 5.582031 -1.265625 5.40625 -1.015625 C 5.226562 -0.765625 5.03125 -0.566406 4.8125 -0.421875 C 4.601562 -0.285156 4.347656 -0.179688 4.046875 -0.109375 C 3.753906 -0.0351562 3.410156 0 3.015625 0 Z M 1.546875 -0.765625 L 2.921875 -0.765625 C 3.347656 -0.765625 3.679688 -0.800781 3.921875 -0.875 C 4.160156 -0.957031 4.351562 -1.070312 4.5 -1.21875 C 4.695312 -1.414062 4.851562 -1.6875 4.96875 -2.03125 C 5.082031 -2.375 5.140625 -2.785156 5.140625 -3.265625 C 5.140625 -3.941406 5.03125 -4.457031 4.8125 -4.8125 C 4.59375 -5.175781 4.320312 -5.421875 4 -5.546875 C 3.769531 -5.640625 3.40625 -5.6875 2.90625 -5.6875 L 1.546875 -5.6875 Z M 1.546875 -0.765625 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-9\" overflow=\"visible\">\n",
       "<path d=\"M 0.84375 0 L 0.84375 -6.4375 L 1.6875 -6.4375 L 1.6875 0 Z M 0.84375 0 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-10\" overflow=\"visible\">\n",
       "<path d=\"M 0.59375 1.78125 L 0.59375 -4.671875 L 1.3125 -4.671875 L 1.3125 -4.0625 C 1.476562 -4.300781 1.664062 -4.476562 1.875 -4.59375 C 2.09375 -4.707031 2.359375 -4.765625 2.671875 -4.765625 C 3.066406 -4.765625 3.414062 -4.660156 3.71875 -4.453125 C 4.019531 -4.253906 4.25 -3.96875 4.40625 -3.59375 C 4.5625 -3.21875 4.640625 -2.8125 4.640625 -2.375 C 4.640625 -1.894531 4.550781 -1.460938 4.375 -1.078125 C 4.207031 -0.691406 3.960938 -0.394531 3.640625 -0.1875 C 3.316406 0.0078125 2.972656 0.109375 2.609375 0.109375 C 2.347656 0.109375 2.113281 0.0507812 1.90625 -0.0625 C 1.695312 -0.175781 1.523438 -0.316406 1.390625 -0.484375 L 1.390625 1.78125 Z M 1.3125 -2.3125 C 1.3125 -1.707031 1.429688 -1.257812 1.671875 -0.96875 C 1.921875 -0.6875 2.21875 -0.546875 2.5625 -0.546875 C 2.90625 -0.546875 3.203125 -0.691406 3.453125 -0.984375 C 3.710938 -1.285156 3.84375 -1.75 3.84375 -2.375 C 3.84375 -2.96875 3.71875 -3.410156 3.46875 -3.703125 C 3.226562 -4.003906 2.9375 -4.15625 2.59375 -4.15625 C 2.257812 -4.15625 1.960938 -3.992188 1.703125 -3.671875 C 1.441406 -3.359375 1.3125 -2.90625 1.3125 -2.3125 Z M 1.3125 -2.3125 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-11\" overflow=\"visible\">\n",
       "<path d=\"M 3.65625 0 L 3.65625 -0.6875 C 3.289062 -0.15625 2.796875 0.109375 2.171875 0.109375 C 1.898438 0.109375 1.644531 0.0546875 1.40625 -0.046875 C 1.164062 -0.160156 0.984375 -0.296875 0.859375 -0.453125 C 0.742188 -0.609375 0.664062 -0.800781 0.625 -1.03125 C 0.59375 -1.1875 0.578125 -1.4375 0.578125 -1.78125 L 0.578125 -4.671875 L 1.359375 -4.671875 L 1.359375 -2.078125 C 1.359375 -1.660156 1.378906 -1.382812 1.421875 -1.25 C 1.460938 -1.039062 1.5625 -0.875 1.71875 -0.75 C 1.882812 -0.632812 2.085938 -0.578125 2.328125 -0.578125 C 2.566406 -0.578125 2.789062 -0.632812 3 -0.75 C 3.207031 -0.875 3.351562 -1.039062 3.4375 -1.25 C 3.519531 -1.457031 3.5625 -1.765625 3.5625 -2.171875 L 3.5625 -4.671875 L 4.359375 -4.671875 L 4.359375 0 Z M 3.65625 0 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-12\" overflow=\"visible\">\n",
       "<path d=\"M 2.328125 -0.703125 L 2.4375 -0.015625 C 2.207031 0.0351562 2.007812 0.0625 1.84375 0.0625 C 1.550781 0.0625 1.328125 0.015625 1.171875 -0.078125 C 1.015625 -0.171875 0.898438 -0.289062 0.828125 -0.4375 C 0.765625 -0.582031 0.734375 -0.890625 0.734375 -1.359375 L 0.734375 -4.046875 L 0.15625 -4.046875 L 0.15625 -4.671875 L 0.734375 -4.671875 L 0.734375 -5.828125 L 1.53125 -6.296875 L 1.53125 -4.671875 L 2.328125 -4.671875 L 2.328125 -4.046875 L 1.53125 -4.046875 L 1.53125 -1.328125 C 1.53125 -1.097656 1.539062 -0.953125 1.5625 -0.890625 C 1.59375 -0.828125 1.640625 -0.773438 1.703125 -0.734375 C 1.765625 -0.691406 1.851562 -0.671875 1.96875 -0.671875 C 2.0625 -0.671875 2.179688 -0.679688 2.328125 -0.703125 Z M 2.328125 -0.703125 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-13\" overflow=\"visible\">\n",
       "<path d=\"M 0.40625 -2.0625 L 1.203125 -2.140625 C 1.242188 -1.816406 1.332031 -1.550781 1.46875 -1.34375 C 1.613281 -1.132812 1.832031 -0.96875 2.125 -0.84375 C 2.414062 -0.71875 2.742188 -0.65625 3.109375 -0.65625 C 3.429688 -0.65625 3.71875 -0.703125 3.96875 -0.796875 C 4.21875 -0.890625 4.40625 -1.019531 4.53125 -1.1875 C 4.65625 -1.363281 4.71875 -1.550781 4.71875 -1.75 C 4.71875 -1.945312 4.65625 -2.117188 4.53125 -2.265625 C 4.414062 -2.421875 4.222656 -2.550781 3.953125 -2.65625 C 3.785156 -2.726562 3.40625 -2.832031 2.8125 -2.96875 C 2.21875 -3.113281 1.800781 -3.25 1.5625 -3.375 C 1.257812 -3.53125 1.03125 -3.726562 0.875 -3.96875 C 0.726562 -4.207031 0.65625 -4.476562 0.65625 -4.78125 C 0.65625 -5.101562 0.742188 -5.40625 0.921875 -5.6875 C 1.109375 -5.96875 1.378906 -6.179688 1.734375 -6.328125 C 2.085938 -6.472656 2.484375 -6.546875 2.921875 -6.546875 C 3.398438 -6.546875 3.820312 -6.46875 4.1875 -6.3125 C 4.550781 -6.164062 4.828125 -5.941406 5.015625 -5.640625 C 5.210938 -5.335938 5.320312 -5 5.34375 -4.625 L 4.515625 -4.5625 C 4.472656 -4.96875 4.328125 -5.273438 4.078125 -5.484375 C 3.828125 -5.691406 3.453125 -5.796875 2.953125 -5.796875 C 2.441406 -5.796875 2.066406 -5.703125 1.828125 -5.515625 C 1.585938 -5.328125 1.46875 -5.097656 1.46875 -4.828125 C 1.46875 -4.597656 1.550781 -4.410156 1.71875 -4.265625 C 1.882812 -4.109375 2.3125 -3.953125 3 -3.796875 C 3.695312 -3.640625 4.175781 -3.503906 4.4375 -3.390625 C 4.8125 -3.222656 5.085938 -3.003906 5.265625 -2.734375 C 5.441406 -2.472656 5.53125 -2.164062 5.53125 -1.8125 C 5.53125 -1.476562 5.429688 -1.15625 5.234375 -0.84375 C 5.035156 -0.539062 4.753906 -0.304688 4.390625 -0.140625 C 4.023438 0.0234375 3.613281 0.109375 3.15625 0.109375 C 2.570312 0.109375 2.082031 0.0234375 1.6875 -0.140625 C 1.289062 -0.316406 0.976562 -0.570312 0.75 -0.90625 C 0.53125 -1.25 0.414062 -1.632812 0.40625 -2.0625 Z M 0.40625 -2.0625 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-14\" overflow=\"visible\">\n",
       "<path d=\"M 3.640625 -0.578125 C 3.347656 -0.328125 3.066406 -0.148438 2.796875 -0.046875 C 2.523438 0.0546875 2.234375 0.109375 1.921875 0.109375 C 1.410156 0.109375 1.015625 -0.015625 0.734375 -0.265625 C 0.460938 -0.515625 0.328125 -0.835938 0.328125 -1.234375 C 0.328125 -1.460938 0.378906 -1.671875 0.484375 -1.859375 C 0.585938 -2.046875 0.722656 -2.195312 0.890625 -2.3125 C 1.054688 -2.425781 1.242188 -2.515625 1.453125 -2.578125 C 1.609375 -2.609375 1.84375 -2.644531 2.15625 -2.6875 C 2.800781 -2.757812 3.273438 -2.851562 3.578125 -2.96875 C 3.578125 -3.070312 3.578125 -3.140625 3.578125 -3.171875 C 3.578125 -3.492188 3.503906 -3.71875 3.359375 -3.84375 C 3.148438 -4.03125 2.847656 -4.125 2.453125 -4.125 C 2.078125 -4.125 1.800781 -4.054688 1.625 -3.921875 C 1.445312 -3.796875 1.316406 -3.566406 1.234375 -3.234375 L 0.46875 -3.328125 C 0.53125 -3.660156 0.640625 -3.925781 0.796875 -4.125 C 0.960938 -4.332031 1.195312 -4.488281 1.5 -4.59375 C 1.8125 -4.707031 2.164062 -4.765625 2.5625 -4.765625 C 2.96875 -4.765625 3.289062 -4.71875 3.53125 -4.625 C 3.78125 -4.53125 3.960938 -4.410156 4.078125 -4.265625 C 4.203125 -4.128906 4.285156 -3.953125 4.328125 -3.734375 C 4.359375 -3.597656 4.375 -3.359375 4.375 -3.015625 L 4.375 -1.953125 C 4.375 -1.222656 4.390625 -0.757812 4.421875 -0.5625 C 4.453125 -0.363281 4.519531 -0.175781 4.625 0 L 3.796875 0 C 3.710938 -0.164062 3.660156 -0.359375 3.640625 -0.578125 Z M 3.578125 -2.34375 C 3.285156 -2.226562 2.851562 -2.128906 2.28125 -2.046875 C 1.957031 -1.992188 1.726562 -1.9375 1.59375 -1.875 C 1.457031 -1.820312 1.351562 -1.738281 1.28125 -1.625 C 1.207031 -1.507812 1.171875 -1.382812 1.171875 -1.25 C 1.171875 -1.039062 1.25 -0.863281 1.40625 -0.71875 C 1.5625 -0.582031 1.796875 -0.515625 2.109375 -0.515625 C 2.410156 -0.515625 2.679688 -0.582031 2.921875 -0.71875 C 3.160156 -0.851562 3.335938 -1.035156 3.453125 -1.265625 C 3.535156 -1.441406 3.578125 -1.703125 3.578125 -2.046875 Z M 3.578125 -2.34375 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "<symbol id=\"glyph0-15\" overflow=\"visible\">\n",
       "<path d=\"M 0.4375 -3.140625 C 0.4375 -4.203125 0.722656 -5.035156 1.296875 -5.640625 C 1.867188 -6.253906 2.609375 -6.5625 3.515625 -6.5625 C 4.109375 -6.5625 4.644531 -6.414062 5.125 -6.125 C 5.601562 -5.84375 5.96875 -5.445312 6.21875 -4.9375 C 6.46875 -4.425781 6.59375 -3.851562 6.59375 -3.21875 C 6.59375 -2.5625 6.460938 -1.972656 6.203125 -1.453125 C 5.941406 -0.941406 5.566406 -0.550781 5.078125 -0.28125 C 4.597656 -0.0195312 4.078125 0.109375 3.515625 0.109375 C 2.910156 0.109375 2.367188 -0.0351562 1.890625 -0.328125 C 1.410156 -0.617188 1.046875 -1.019531 0.796875 -1.53125 C 0.554688 -2.039062 0.4375 -2.578125 0.4375 -3.140625 Z M 1.3125 -3.125 C 1.3125 -2.34375 1.519531 -1.726562 1.9375 -1.28125 C 2.351562 -0.84375 2.878906 -0.625 3.515625 -0.625 C 4.148438 -0.625 4.675781 -0.847656 5.09375 -1.296875 C 5.507812 -1.742188 5.71875 -2.382812 5.71875 -3.21875 C 5.71875 -3.738281 5.628906 -4.191406 5.453125 -4.578125 C 5.273438 -4.972656 5.015625 -5.28125 4.671875 -5.5 C 4.328125 -5.71875 3.945312 -5.828125 3.53125 -5.828125 C 2.925781 -5.828125 2.40625 -5.617188 1.96875 -5.203125 C 1.53125 -4.785156 1.3125 -4.09375 1.3125 -3.125 Z M 1.3125 -3.125 \" style=\"stroke:none;\"/>\n",
       "</symbol>\n",
       "</g>\n",
       "</defs>\n",
       "<g id=\"surface1\">\n",
       "<path d=\"M 123.75 59 L 175.148438 59 L 175.148438 91.199219 L 123.75 91.199219 Z M 123.75 59 \" style=\"fill-rule:nonzero;fill:rgb(69.804382%,85.098267%,100%);fill-opacity:1;stroke-width:1;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"70.6893\" xlink:href=\"#glyph0-1\" y=\"19.6\"/>\n",
       "  <use x=\"76.6923\" xlink:href=\"#glyph0-2\" y=\"19.6\"/>\n",
       "</g>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"81.6981\" xlink:href=\"#glyph0-3\" y=\"19.6\"/>\n",
       "  <use x=\"86.1981\" xlink:href=\"#glyph0-4\" y=\"19.6\"/>\n",
       "</g>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"91.2039\" xlink:href=\"#glyph0-5\" y=\"19.6\"/>\n",
       "</g>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"96.2097\" xlink:href=\"#glyph0-6\" y=\"19.6\"/>\n",
       "</g>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"101.2155\" xlink:href=\"#glyph0-7\" y=\"19.6\"/>\n",
       "</g>\n",
       "<path d=\"M 253.050781 59 L 304.449219 59 L 304.449219 91.199219 L 253.050781 91.199219 Z M 253.050781 59 \" style=\"fill-rule:nonzero;fill:rgb(69.804382%,85.098267%,100%);fill-opacity:1;stroke-width:1;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"199.741\" xlink:href=\"#glyph0-8\" y=\"19.6\"/>\n",
       "</g>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"206.2408\" xlink:href=\"#glyph0-6\" y=\"19.6\"/>\n",
       "</g>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"211.2466\" xlink:href=\"#glyph0-3\" y=\"19.6\"/>\n",
       "  <use x=\"215.7466\" xlink:href=\"#glyph0-4\" y=\"19.6\"/>\n",
       "</g>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"220.7524\" xlink:href=\"#glyph0-5\" y=\"19.6\"/>\n",
       "</g>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"225.7582\" xlink:href=\"#glyph0-6\" y=\"19.6\"/>\n",
       "</g>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"230.764\" xlink:href=\"#glyph0-7\" y=\"19.6\"/>\n",
       "</g>\n",
       "<path d=\"M 175.148438 75.101562 L 186.800781 75.101562 \" style=\"fill:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "<path d=\"M 190.800781 75.101562 L 186.800781 75.101562 M 186.800781 73.601562 L 190.800781 75.101562 L 186.800781 76.601562 \" style=\"fill:none;stroke-width:1;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "<path d=\"M 106.199219 75.101562 L 117.851562 75.101562 \" style=\"fill:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "<path d=\"M 121.851562 75.101562 L 117.851562 75.101562 M 117.851562 73.601562 L 121.851562 75.101562 L 117.851562 76.601562 \" style=\"fill:none;stroke-width:1;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "<path d=\"M 304.449219 75.101562 L 316.101562 75.101562 \" style=\"fill:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "<path d=\"M 320.101562 75.101562 L 316.101562 75.101562 M 316.101562 73.601562 L 320.101562 75.101562 L 316.101562 76.601562 \" style=\"fill:none;stroke-width:1;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "<path d=\"M 63.398438 64.601562 L 106.199219 64.601562 L 106.199219 85.601562 L 63.398438 85.601562 Z M 63.398438 64.601562 \" style=\"fill-rule:nonzero;fill:rgb(100%,100%,100%);fill-opacity:1;stroke-width:1;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"12.79146\" xlink:href=\"#glyph0-9\" y=\"19.6\"/>\n",
       "  <use x=\"15.29166\" xlink:href=\"#glyph0-2\" y=\"19.6\"/>\n",
       "  <use x=\"20.29386\" xlink:href=\"#glyph0-10\" y=\"19.6\"/>\n",
       "  <use x=\"25.29606\" xlink:href=\"#glyph0-11\" y=\"19.6\"/>\n",
       "  <use x=\"30.29826\" xlink:href=\"#glyph0-12\" y=\"19.6\"/>\n",
       "</g>\n",
       "<path d=\"M 192.699219 64.601562 L 235.5 64.601562 L 235.5 85.601562 L 192.699219 85.601562 Z M 192.699219 64.601562 \" style=\"fill-rule:nonzero;fill:rgb(100%,100%,100%);fill-opacity:1;stroke-width:1;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"141.5927\" xlink:href=\"#glyph0-13\" y=\"19.6\"/>\n",
       "  <use x=\"147.5957\" xlink:href=\"#glyph0-12\" y=\"19.6\"/>\n",
       "</g>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"150.0959\" xlink:href=\"#glyph0-14\" y=\"19.6\"/>\n",
       "</g>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"155.1017\" xlink:href=\"#glyph0-12\" y=\"19.6\"/>\n",
       "</g>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"157.6019\" xlink:href=\"#glyph0-6\" y=\"19.6\"/>\n",
       "</g>\n",
       "<path d=\"M 322 64.601562 L 364.800781 64.601562 L 364.800781 85.601562 L 322 85.601562 Z M 322 64.601562 \" style=\"fill-rule:nonzero;fill:rgb(100%,100%,100%);fill-opacity:1;stroke-width:1;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "<g style=\"fill:rgb(0%,0%,0%);fill-opacity:1;\">\n",
       "  <use x=\"267.8912\" xlink:href=\"#glyph0-15\" y=\"19.6\"/>\n",
       "  <use x=\"274.8914\" xlink:href=\"#glyph0-11\" y=\"19.6\"/>\n",
       "  <use x=\"279.8936\" xlink:href=\"#glyph0-12\" y=\"19.6\"/>\n",
       "  <use x=\"282.3938\" xlink:href=\"#glyph0-10\" y=\"19.6\"/>\n",
       "  <use x=\"287.396\" xlink:href=\"#glyph0-11\" y=\"19.6\"/>\n",
       "  <use x=\"292.3982\" xlink:href=\"#glyph0-12\" y=\"19.6\"/>\n",
       "</g>\n",
       "<path d=\"M 235.5 75.101562 L 247.148438 75.101562 \" style=\"fill:none;stroke-width:1;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "<path d=\"M 251.148438 75.101562 L 247.148438 75.101562 M 247.148438 73.601562 L 251.148438 75.101562 L 247.148438 76.601562 \" style=\"fill:none;stroke-width:1;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;\" transform=\"matrix(1,0,0,1,-62,-58)\"/>\n",
       "</g>\n",
       "</svg>"
      ],
      "text/plain": [
       "<IPython.core.display.SVG object>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import SVG\n",
    "SVG('img/encoder-decoder.svg')"
   ]
  },
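  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make this data flow concrete, here is a minimal sketch that is not part of the interface defined below. It\n",
    "assumes, purely for illustration, that both parts are PyTorch `nn.GRU` layers: the encoder’s final hidden state\n",
    "serves as the state, and the decoder starts from that state instead of from zeros. All sizes are arbitrary toy\n",
    "values, and the token ids are random stand-ins for tokenized sentences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "vocab_size, embed_size, num_hiddens = 10, 8, 16  # toy sizes (assumed)\n",
    "batch_size, src_len, tgt_len = 2, 5, 4\n",
    "\n",
    "src_embedding = nn.Embedding(vocab_size, embed_size)\n",
    "tgt_embedding = nn.Embedding(vocab_size, embed_size)\n",
    "encoder_rnn = nn.GRU(embed_size, num_hiddens, batch_first=True)\n",
    "decoder_rnn = nn.GRU(embed_size, num_hiddens, batch_first=True)\n",
    "\n",
    "# Random token ids standing in for tokenized source and target sentences.\n",
    "src = torch.randint(0, vocab_size, (batch_size, src_len))\n",
    "tgt = torch.randint(0, vocab_size, (batch_size, tgt_len))\n",
    "\n",
    "# Encoder: its final hidden state summarizes the whole source sequence.\n",
    "_, state = encoder_rnn(src_embedding(src))\n",
    "print(state.shape)  # torch.Size([1, 2, 16]): (num_layers, batch_size, num_hiddens)\n",
    "\n",
    "# Decoder: sees the source only through the state it starts from.\n",
    "dec_outputs, new_state = decoder_rnn(tgt_embedding(tgt), state)\n",
    "print(dec_outputs.shape)  # torch.Size([2, 4, 16]): (batch_size, tgt_len, num_hiddens)"
   ]
  },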
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we define an interface for implementing the encoder-decoder architecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Encoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The encoder is an ordinary neural network that takes inputs, e.g., a source sentence, and returns outputs. The base class below only fixes this interface; concrete encoders override `forward`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
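    "class Encoder(nn.Module):\n",
    "    # Base class: defines the encoder interface only.\n",
    "    def __init__(self, **kwargs):\n",
    "        super(Encoder, self).__init__(**kwargs)\n",
    "\n",
    "    def forward(self, X, *args):\n",
    "        # Map the inputs X (plus optional extra arguments) to the encoder outputs.\n",
    "        # Concrete subclasses must override this method.\n",
    "        raise NotImplementedError"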
   ]
  },
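  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of how this interface might be filled in, the sketch below subclasses `Encoder` with an\n",
    "embedding layer followed by a GRU. The class name `Seq2SeqEncoder` and all sizes are hypothetical; the encoder\n",
    "simply returns what `nn.GRU` returns, i.e., the outputs at every time step and the final hidden state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class Seq2SeqEncoder(Encoder):  # hypothetical example subclass\n",
    "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers=1, **kwargs):\n",
    "        super(Seq2SeqEncoder, self).__init__(**kwargs)\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, batch_first=True)\n",
    "\n",
    "    def forward(self, X, *args):\n",
    "        # X: (batch_size, num_steps) token indices\n",
    "        X = self.embedding(X)         # (batch_size, num_steps, embed_size)\n",
    "        outputs, state = self.rnn(X)  # outputs: (batch_size, num_steps, num_hiddens)\n",
    "        return outputs, state         # state: (num_layers, batch_size, num_hiddens)\n",
    "\n",
    "encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)\n",
    "X = torch.zeros((4, 7), dtype=torch.long)\n",
    "outputs, state = encoder(X)\n",
    "print(outputs.shape, state.shape)  # torch.Size([4, 7, 16]) torch.Size([2, 4, 16])"
   ]
  },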
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The decoder has an additional method, `init_state`, that converts the outputs of the encoder, together with\n",
    "possible additional information such as the valid lengths of the inputs, into the state the decoder needs.\n",
    "In the `forward` method, the decoder takes both its inputs, e.g., a target sentence, and the state. It returns\n",
    "the outputs, together with a potentially modified state if the decoder contains RNN layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
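    "class Decoder(nn.Module):\n",
    "    # Base class: defines the decoder interface only.\n",
    "    def __init__(self, **kwargs):\n",
    "        super(Decoder, self).__init__(**kwargs)\n",
    "\n",
    "    def init_state(self, enc_outputs, *args):\n",
    "        # Convert the encoder outputs (plus optional extra inputs,\n",
    "        # e.g. valid lengths) into the initial decoder state.\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def forward(self, X, state):\n",
    "        # Map the decoder inputs X and the current state to outputs\n",
    "        # (and, typically, an updated state).\n",
    "        raise NotImplementedError"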
   ]
  },
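  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A matching decoder could be sketched as follows, again with a hypothetical class name and toy sizes. Its\n",
    "`init_state` assumes the encoder returns a pair (outputs, final hidden state), as `nn.GRU` does, and keeps only\n",
    "the hidden state. Because this decoder contains an RNN layer, `forward` returns the updated state along with\n",
    "the output logits. The encoder outputs below are zero tensors standing in for real ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class Seq2SeqDecoder(Decoder):  # hypothetical example subclass\n",
    "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers=1, **kwargs):\n",
    "        super(Seq2SeqDecoder, self).__init__(**kwargs)\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, batch_first=True)\n",
    "        self.dense = nn.Linear(num_hiddens, vocab_size)\n",
    "\n",
    "    def init_state(self, enc_outputs, *args):\n",
    "        # Assumes enc_outputs == (outputs, hidden_state); keep the hidden state.\n",
    "        return enc_outputs[1]\n",
    "\n",
    "    def forward(self, X, state):\n",
    "        X = self.embedding(X)                # (batch_size, num_steps, embed_size)\n",
    "        outputs, state = self.rnn(X, state)  # the RNN updates the state\n",
    "        return self.dense(outputs), state    # logits: (batch_size, num_steps, vocab_size)\n",
    "\n",
    "decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)\n",
    "# Stand-in encoder outputs: hidden state shaped (num_layers, batch_size, num_hiddens).\n",
    "enc_outputs = (torch.zeros((4, 7, 16)), torch.zeros((2, 4, 16)))\n",
    "state = decoder.init_state(enc_outputs)\n",
    "Y, state = decoder(torch.zeros((4, 7), dtype=torch.long), state)\n",
    "print(Y.shape, state.shape)  # torch.Size([4, 7, 10]) torch.Size([2, 4, 16])"
   ]
  },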
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The encoder-decoder model contains both an encoder and a decoder. We implement its `forward` method for\n",
    "training. It takes both the encoder inputs and the decoder inputs, with optional additional information.\n",
    "During computation, it first computes the encoder outputs to initialize the decoder state, and then returns\n",
    "the decoder outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
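    "class EncoderDecoder(nn.Module):\n",
    "    # Base class that chains an encoder and a decoder.\n",
    "    def __init__(self, encoder, decoder, **kwargs):\n",
    "        super(EncoderDecoder, self).__init__(**kwargs)\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def forward(self, enc_X, dec_X, *args):\n",
    "        # Encode the source inputs; extra args (e.g. valid lengths) go to both parts.\n",
    "        enc_outputs = self.encoder(enc_X, *args)\n",
    "        # Turn the encoder outputs into the decoder's initial state.\n",
    "        dec_state = self.decoder.init_state(enc_outputs, *args)\n",
    "        # Run the decoder on its inputs, conditioned on that state.\n",
    "        return self.decoder(dec_X, dec_state)"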
   ]
  }
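  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check that the pieces fit together, the sketch below plugs two deliberately tiny, hypothetical subclasses\n",
    "into `EncoderDecoder`: `ToyEncoder` is a single linear layer, and `ToyDecoder` uses the mean of the encoder\n",
    "outputs over time as its state (an assumption made only for this example). Calling the model runs the encoder,\n",
    "`init_state`, and the decoder in that order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class ToyEncoder(Encoder):  # hypothetical: a single linear layer\n",
    "    def __init__(self, in_dim, hidden_dim, **kwargs):\n",
    "        super(ToyEncoder, self).__init__(**kwargs)\n",
    "        self.linear = nn.Linear(in_dim, hidden_dim)\n",
    "\n",
    "    def forward(self, X, *args):\n",
    "        return torch.tanh(self.linear(X))  # (batch_size, num_steps, hidden_dim)\n",
    "\n",
    "class ToyDecoder(Decoder):  # hypothetical: conditions on the mean encoder output\n",
    "    def __init__(self, in_dim, hidden_dim, out_dim, **kwargs):\n",
    "        super(ToyDecoder, self).__init__(**kwargs)\n",
    "        self.linear = nn.Linear(in_dim + hidden_dim, out_dim)\n",
    "\n",
    "    def init_state(self, enc_outputs, *args):\n",
    "        return enc_outputs.mean(dim=1)  # (batch_size, hidden_dim)\n",
    "\n",
    "    def forward(self, X, state):\n",
    "        # Repeat the state for every decoder time step and concatenate it with X.\n",
    "        context = state.unsqueeze(1).expand(-1, X.shape[1], -1)\n",
    "        return self.linear(torch.cat([X, context], dim=-1)), state\n",
    "\n",
    "model = EncoderDecoder(ToyEncoder(3, 5), ToyDecoder(4, 5, 6))\n",
    "enc_X = torch.randn(2, 7, 3)  # (batch_size, source steps, features)\n",
    "dec_X = torch.randn(2, 4, 4)  # (batch_size, target steps, features)\n",
    "outputs, state = model(enc_X, dec_X)\n",
    "print(outputs.shape, state.shape)  # torch.Size([2, 4, 6]) torch.Size([2, 5])"
   ]
  }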
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
